{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04206.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04206.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ylH4KkeDCf71",
        "colab_type": "text"
      },
      "source": [
        "## 4.1 模型构造"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6prnAHrrDGBW",
        "colab_type": "text"
      },
      "source": [
        "### 4.1.1 继承 Module 类来构造模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PYh3R27SGkrc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch \n",
        "from torch import nn\n",
        "\n",
        "class MLP(nn.Module):\n",
        "    \"\"\"Multilayer perceptron: 784 inputs, one 256-unit ReLU hidden layer, 10 outputs.\"\"\"\n",
        "\n",
        "    def __init__(self, **kwargs):\n",
        "        # Keyword arguments are forwarded unchanged to nn.Module.__init__.\n",
        "        super().__init__(**kwargs)\n",
        "        self.hidden = nn.Linear(784, 256)\n",
        "        self.act = nn.ReLU()\n",
        "        self.output = nn.Linear(256, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Pass ``x`` through hidden layer, activation and output layer, in that order.\"\"\"\n",
        "        for layer in (self.hidden, self.act, self.output):\n",
        "            x = layer(x)\n",
        "        return x"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "blES5ahaHQNc",
        "colab_type": "code",
        "outputId": "9d394817-3857-4850-85c8-1a26edb0ad45",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        }
      },
      "source": [
        "# Build the MLP defined above; the trailing bare expression displays its submodules.\n",
        "net = MLP()\n",
        "net"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "MLP(\n",
              "  (hidden): Linear(in_features=784, out_features=256, bias=True)\n",
              "  (act): ReLU()\n",
              "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r33gZw8GHdnR",
        "colab_type": "code",
        "outputId": "1ae517d8-1d37-4af2-d9ba-837865bf5a41",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        }
      },
      "source": [
        "# Random batch of 2 samples with 784 features each. Calling net(X) runs MLP.forward\n",
        "# and yields one row of 10 outputs per sample.\n",
        "X = torch.rand(2, 784)\n",
        "net(X)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[ 0.0459,  0.1347, -0.3975,  0.0757, -0.1330, -0.0754,  0.0330,  0.1541,\n",
              "          0.0051,  0.1180],\n",
              "        [ 0.0429,  0.2412, -0.2639,  0.1379,  0.0189,  0.0011, -0.1205,  0.1794,\n",
              "          0.0137,  0.2713]], grad_fn=<AddmmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9C_jzdE7HmUs",
        "colab_type": "text"
      },
      "source": [
        "### 4.1.2 Module 的子类"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1DFyHS0KHwxa",
        "colab_type": "text"
      },
      "source": [
        "#### 1. Sequential 类"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_oZ-4pV3XR1V",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from collections import OrderedDict \n",
        "\n",
        "class MySequential(nn.Module):\n",
        "    \"\"\"Minimal re-implementation of a sequential container.\n",
        "\n",
        "    Accepts either a single OrderedDict mapping names to modules, or any number of\n",
        "    positional modules, which are registered under the names '0', '1', ...\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, *args):\n",
        "        super().__init__()\n",
        "        given_named_dict = len(args) == 1 and isinstance(args[0], OrderedDict)\n",
        "        if given_named_dict:\n",
        "            named_modules = args[0].items()\n",
        "        else:\n",
        "            named_modules = ((str(pos), mod) for pos, mod in enumerate(args))\n",
        "        # add_module stores each entry in self._modules, an ordered mapping.\n",
        "        for name, mod in named_modules:\n",
        "            self.add_module(name, mod)\n",
        "\n",
        "    def forward(self, input):\n",
        "        \"\"\"Feed ``input`` through every registered module in registration order.\"\"\"\n",
        "        result = input\n",
        "        for mod in self._modules.values():\n",
        "            result = mod(result)\n",
        "        return result"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "teBev77GYrOW",
        "colab_type": "code",
        "outputId": "c4ab5af4-841b-46f5-aaae-fa3c172cd6e0",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        }
      },
      "source": [
        "# Same 784 -> 256 -> 10 architecture as MLP, assembled from positional modules;\n",
        "# MySequential registers them under the names '0', '1', '2'.\n",
        "net = MySequential(\n",
        "    nn.Linear(784, 256), \n",
        "    nn.ReLU(), \n",
        "    nn.Linear(256, 10), \n",
        ")\n",
        "net "
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "MySequential(\n",
              "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
              "  (1): ReLU()\n",
              "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jzaCg92OZACi",
        "colab_type": "code",
        "outputId": "8744f4c7-447b-4fc7-9b94-270f8e78550c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        }
      },
      "source": [
        "# Reuses the (2, 784) batch X created in an earlier cell; the modules are applied\n",
        "# in registration order.\n",
        "net(X)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[-0.1612,  0.0973, -0.0472,  0.0649, -0.0788,  0.0777, -0.1005,  0.0287,\n",
              "          0.2223,  0.0737],\n",
              "        [-0.1212,  0.1584, -0.0714,  0.1273, -0.0960,  0.0352, -0.1249,  0.0433,\n",
              "          0.1862, -0.1236]], grad_fn=<AddmmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "amUC_x-PZCXa",
        "colab_type": "text"
      },
      "source": [
        "#### 2. ModuleList 类"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MnjXyru7ZI4w",
        "colab_type": "code",
        "outputId": "95b7f090-5322-402f-dc3c-ef8b18bc4bc4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        }
      },
      "source": [
        "# nn.ModuleList is built from a Python list of modules and, like a list, supports append.\n",
        "net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])\n",
        "net.append(nn.Linear(256, 10))\n",
        "net"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "ModuleList(\n",
              "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
              "  (1): ReLU()\n",
              "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IcbUTiC1ZglE",
        "colab_type": "code",
        "outputId": "2067511a-0e24-4f97-c07c-9d6f48450e51",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Negative indexing works as on a Python list: this is the Linear(256, 10) appended last.\n",
        "net[-1]"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Linear(in_features=256, out_features=10, bias=True)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4DiEheauZi6Z",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Module_ModuleList(nn.Module):\n",
        "    \"\"\"Holds a single Linear(10, 10) layer inside an nn.ModuleList.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        layers = [nn.Linear(10, 10)]\n",
        "        # Wrapping the list in nn.ModuleList registers the layer as a submodule.\n",
        "        self.linears = nn.ModuleList(layers)\n",
        "\n",
        "\n",
        "class Module_List(nn.Module):\n",
        "    \"\"\"Counter-example: holds a single Linear(10, 10) layer inside a plain Python list.\n",
        "\n",
        "    The plain list is intentional. A layer stored this way is not registered as a\n",
        "    submodule, so its parameters are not reported by parameters().\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        layer = nn.Linear(10, 10)\n",
        "        self.linears = [layer]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PVOix16uMdEw",
        "colab_type": "code",
        "outputId": "d004b739-0298-4b5b-bf5a-11e300994acc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 134
        }
      },
      "source": [
        "# With nn.ModuleList the inner Linear is registered: it appears in the repr and\n",
        "# parameters() yields its weight (10, 10) and bias (10,).\n",
        "net1 = Module_ModuleList()\n",
        "\n",
        "print(net1)\n",
        "\n",
        "for p in net1.parameters():\n",
        "    print(p.size())"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Module_ModuleList(\n",
            "  (linears): ModuleList(\n",
            "    (0): Linear(in_features=10, out_features=10, bias=True)\n",
            "  )\n",
            ")\n",
            "torch.Size([10, 10])\n",
            "torch.Size([10])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hyaIcAgiMxPo",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "7dd1ab14-07f5-4e65-86b0-e8261814dbbc"
      },
      "source": [
        "# With a plain Python list the inner Linear is NOT registered: the repr shows no\n",
        "# submodules and parameters() yields nothing, so the loop below prints no sizes.\n",
        "net2 = Module_List()\n",
        "\n",
        "print(net2)\n",
        "\n",
        "for p in net2.parameters():\n",
        "    print(p.size())"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Module_List()\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bpJXT5h9NB7J",
        "colab_type": "text"
      },
      "source": [
        "#### 3. ModuleDict 类"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hIg7DrjBTlft",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "3867c4c4-bdd1-449c-dd51-cc8cc4875c3d"
      },
      "source": [
        "# nn.ModuleDict is built from a dict of modules and accepts new entries by key assignment.\n",
        "# The trailing expression looks one entry up by its key.\n",
        "net = nn.ModuleDict({\n",
        "    'linear': nn.Linear(784, 256), \n",
        "    'act': nn.ReLU(), \n",
        "})\n",
        "net['output'] = nn.Linear(256, 10)\n",
        "net['linear']"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Linear(in_features=784, out_features=256, bias=True)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "V1Imti7EkCi5",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "b4653780-2ec3-41cb-8d39-f605d1920984"
      },
      "source": [
        "# Entries of the ModuleDict are also reachable as attributes.\n",
        "net.output"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Linear(in_features=256, out_features=10, bias=True)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uNXvQ7xPkGRo",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "7239eeb8-c54e-481a-c1d3-bf9d297f21a1"
      },
      "source": [
        "# NOTE(review): the saved output lists the keys alphabetically (act, linear, output)\n",
        "# rather than in insertion order; this presumably depends on the PyTorch version -\n",
        "# confirm when re-running.\n",
        "net"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "ModuleDict(\n",
              "  (act): ReLU()\n",
              "  (linear): Linear(in_features=784, out_features=256, bias=True)\n",
              "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fFYsrbdwkHRw",
        "colab_type": "text"
      },
      "source": [
        "### 4.1.3 构造复杂的模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DduSoM2vkOJP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class FancyMLP(nn.Module):\n",
        "    \"\"\"MLP with a constant random weight, a shared linear layer and Python control flow.\n",
        "\n",
        "    forward() expects a 2-D input with 20 features per row and returns a scalar tensor.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "\n",
        "        # Constant, non-trainable weight. Registering it as a non-persistent buffer keeps\n",
        "        # it out of parameters() and state_dict(), while .to() / .cuda() / .double() now\n",
        "        # move or cast it together with the rest of the module. As a plain tensor\n",
        "        # attribute it would be left behind and torch.mm would fail on a device or\n",
        "        # dtype mismatch.\n",
        "        self.register_buffer(\n",
        "            \"rand_weight\", torch.rand((20, 20), requires_grad=False), persistent=False\n",
        "        )\n",
        "        self.linear = nn.Linear(20, 20)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Run the network on ``x`` and return the sum of the final activations.\"\"\"\n",
        "        x = self.linear(x)\n",
        "        # rand_weight does not require grad, so it acts as a constant in the graph.\n",
        "        x = nn.functional.relu(torch.mm(x, self.rand_weight) + 1)\n",
        "        # Reuse the same linear layer: both applications share one set of parameters.\n",
        "        x = self.linear(x)\n",
        "        # Data-dependent control flow: halve until the norm is at most 1 ...\n",
        "        while x.norm().item() > 1:\n",
        "            x /= 2\n",
        "        # ... then scale up tenfold if the norm ended up below 0.8.\n",
        "        if x.norm().item() < 0.8:\n",
        "            x *= 10\n",
        "        return x.sum()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jkBFzM1KlK9P",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "4dee8bfa-e41d-431e-cb0c-ef27cdd43e7a"
      },
      "source": [
        "# Only `linear` shows up in the repr: rand_weight is a tensor, not a submodule.\n",
        "net = FancyMLP()\n",
        "net "
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "FancyMLP(\n",
              "  (linear): Linear(in_features=20, out_features=20, bias=True)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "E7OXUhIBm6Nt",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "9483a0cc-b229-4647-bdbc-7bd4ebe3f3bc"
      },
      "source": [
        "# X is rebound here to a (2, 20) batch to match FancyMLP's 20 input features;\n",
        "# forward returns a scalar tensor (the sum of the final activations).\n",
        "X = torch.rand(2, 20)\n",
        "net(X)"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor(-13.4920, grad_fn=<SumBackward0>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZQYBXxBlm937",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class NestMLP(nn.Module):\n",
        "    \"\"\"Wraps an nn.Sequential (Linear 40 -> 30 followed by ReLU) as the submodule ``net``.\"\"\"\n",
        "\n",
        "    def __init__(self, **kwargs):\n",
        "        # Keyword arguments are forwarded unchanged to nn.Module.__init__.\n",
        "        super().__init__(**kwargs)\n",
        "        layers = [nn.Linear(40, 30), nn.ReLU()]\n",
        "        self.net = nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Delegate to the wrapped Sequential and return its output.\"\"\"\n",
        "        out = self.net(x)\n",
        "        return out"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JjkRf0Yzncq6",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 218
        },
        "outputId": "1317c530-9a95-4f07-8835-f850ec78281a"
      },
      "source": [
        "# Custom modules nest freely inside nn.Sequential:\n",
        "# NestMLP (40 -> 30) -> Linear (30 -> 20) -> FancyMLP (20 features in, scalar out).\n",
        "net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())\n",
        "net"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Sequential(\n",
              "  (0): NestMLP(\n",
              "    (net): Sequential(\n",
              "      (0): Linear(in_features=40, out_features=30, bias=True)\n",
              "      (1): ReLU()\n",
              "    )\n",
              "  )\n",
              "  (1): Linear(in_features=30, out_features=20, bias=True)\n",
              "  (2): FancyMLP(\n",
              "    (linear): Linear(in_features=20, out_features=20, bias=True)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z5oVrQI8nk-_",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "f5e3d2e8-4cdc-483b-fbe6-1db22c382186"
      },
      "source": [
        "# X is rebound to a (2, 40) batch to match NestMLP's 40 input features; the final\n",
        "# FancyMLP stage reduces the result to a scalar.\n",
        "X = torch.rand(2, 40)\n",
        "net(X)"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor(10.6316, grad_fn=<SumBackward0>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 24
        }
      ]
    }
  ]
}